import numpy as np
import gymnasium as gym
from tianshou.data.batch import Batch

class Extractor:
    """
    Handles factorization of state.

    Splits a flat observation vector into one row per factor, using the
    breakpoints supplied by the environment, and zero-pads every row to the
    width of the longest factor.
    """

    def __init__(self, env: gym.Env):
        """
        Cache the factorization metadata exposed by ``env``.

        Reads ``dict_obs_space``, ``breakpoints``, ``goal_based`` and
        ``num_factors`` from ``env`` and precomputes ``obs_mask``, a boolean
        (num_factors, longest) array marking which padded cells hold real
        observation values.
        """
        self.dict_obs_space = env.dict_obs_space
        # per the original note: includes 0 and has (num_factors + 1) entries
        self.breakpoints = env.breakpoints
        self.goal_based = env.goal_based
        self.num_factors = env.num_factors
        self.env = env

        self.factor_names = list(self.dict_obs_space.keys())

        # width of each segment; the widest one defines the padded row length
        widths = np.diff(np.asarray(self.breakpoints))
        self.longest = int(widths.max())
        self.obs_mask = np.zeros((self.num_factors, self.longest), dtype=bool)
        for i, width in enumerate(widths):
            self.obs_mask[i, :width] = True

    def get_achieved_goal_state(self, state, fidx=None):
        """
        Delegate to ``env.get_achieved_goal_state(state, fidx=fidx)`` when the
        environment defines it; otherwise return ``state`` unchanged.
        """
        if hasattr(self.env, "get_achieved_goal_state"):
            return self.env.get_achieved_goal_state(state, fidx=fidx)
        return state

    def slice_targets(self, batch, next=False, append_id=False, append_act=False, append_rew_done=False, flatten=False):
        """
        takes an observation of state of length n, and breaks it into segments at the breakpoints
        padded to the longest segment length

        Args:
            batch: a ``Batch`` (its ``obs`` / ``obs_next`` is used, and the
                ``observation`` field of that if it is itself a ``Batch``),
                or directly an array whose last axis is the flat state.
            next: use ``batch.obs_next`` instead of ``batch.obs``.
            append_id: append a one-hot row identifier to every row.
            append_act: prepend a padded row holding ``batch.act`` (zeros if
                ``batch`` is not a ``Batch`` or has no ``act``).
            append_rew_done: append padded rows whose first cell holds
                ``batch.rew`` and ``batch.done`` respectively (zeros if absent).
            flatten: collapse the result to ``(bs[0], -1)``, or to 1-D when the
                input has no leading batch dimensions.

        Returns:
            Array of shape ``bs + (rows, cols)`` where ``bs`` is the input shape
            without its last axis, ``rows`` is ``num_factors`` plus the optional
            action / reward / done rows, and ``cols`` is ``longest`` (plus
            ``rows`` when ``append_id`` is set); reshaped if ``flatten``.
            The extra rows and ids are float64, so concatenating them may
            promote the dtype of the result.
        """
        is_batch = isinstance(batch, Batch)
        obs = (batch.obs_next if next else batch.obs) if is_batch else batch  # allows array inputs
        observation = obs.observation if isinstance(obs, Batch) else obs
        bs = observation.shape[:-1]
        target = np.zeros(bs + (self.num_factors, self.longest), dtype=observation.dtype)
        # scatter the flat state into the valid (unpadded) cells of each row
        target[..., self.obs_mask] = observation

        # Rows are stacked along axis -2 (the factor axis) rather than axis 1,
        # so unbatched inputs and inputs with several batch axes work too.
        if append_act:
            act = np.zeros(bs + (self.longest,))
            if is_batch and "act" in batch:
                act[..., :batch.act.shape[-1]] = batch.act
            target = np.concatenate([np.expand_dims(act, axis=-2), target], axis=-2)

        # append the reward and done if append_rew_done
        if append_rew_done:
            rew = np.zeros(bs + (self.longest,))
            if is_batch and "rew" in batch:
                rew[..., :1] = np.expand_dims(batch.rew, axis=-1)

            done = np.zeros(bs + (self.longest,))
            if is_batch and "done" in batch:
                done[..., :1] = np.expand_dims(batch.done, axis=-1)
            target = np.concatenate(
                [target, np.expand_dims(rew, axis=-2), np.expand_dims(done, axis=-2)], axis=-2
            )

        if append_id:
            num_factors_added = self.num_factors + int(append_act) + 2 * int(append_rew_done)
            # Repeat the identity once per batch element: reps of (1, 1) on the
            # trailing axes keep it (n, n) so it lines up with target's rows.
            # TODO: IDs are probably wrong if multiple instances of the same class
            factor_ids = np.tile(np.eye(num_factors_added), bs + (1, 1))
            target = np.concatenate((target, factor_ids), axis=-1)

        if flatten:
            if len(bs) > 0:
                return target.reshape(bs[0], -1)
            return target.flatten()
        return target
    
def compute_proximity(targets):
    """
    Return the pairwise distances between the entries along axis 1.

    For every index ``i`` along axis 1, the norm over the last axis of
    ``targets - targets[:, i]`` is taken; the per-index results are stacked
    along axis 1, so entry ``[b, i, j]`` is the distance between objects
    ``i`` and ``j`` of batch element ``b`` (a batch size x n x n matrix).
    """
    num_objects = targets.shape[1]
    per_object = [
        # slicing with idx:idx + 1 keeps axis 1 so the difference broadcasts
        np.linalg.norm(targets - targets[:, idx:idx + 1], axis=-1)
        for idx in range(num_objects)
    ]
    return np.stack(per_object, axis=1)
